Streaming inference #1275
Conversation
Oh yes, also pending are changes to audio.py to make it compute features frame-by-frame instead of on the whole file at once. (That, or figure out why those two modes have different outputs and fix it.)
This API should be enough to implement the Web Speech API with some limitations (ordered by decreasing difficulty to solve):
Interim results are easy to solve, although expensive (latency) with our current setup. We can just add a "resultSoFar()" method that decodes the currently accumulated logits. A streaming implementation of the decoder would make this less costly.
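For illustration, here's how such a hypothetical call might look from the Python bindings (a sketch only: setupStream, feedAudioContent and finishStream are the names proposed in this PR, while resultSoFar and the surrounding variable names are made up):

```python
# Hypothetical usage sketch; resultSoFar() does not exist yet.
ctx = model.setupStream()
for chunk in audio_chunks:                      # e.g. 16-bit PCM buffers from a mic
    model.feedAudioContent(ctx, chunk)
    print('interim:', model.resultSoFar(ctx))   # decode the logits accumulated so far
print('final:', model.finishStream(ctx))
```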
Force-pushed from 6b118d9 to ba14779
Took a look; seems pretty good, but I didn't go through deepspeech.cc yet. I'll tackle that when it's no longer a WIP.
DeepSpeech.py
Outdated
@@ -1681,28 +1673,109 @@ def end(self, session):
                   ' or removing the contents of {0}.'.format(FLAGS.checkpoint_dir))
         sys.exit(1)

-def create_inference_graph(batch_size=None, use_new_decoder=False):
+def create_inference_step_graph(batch_x, seq_length, fw_cell, previous_state, batch_size=1, n_steps=16):
How do we guarantee that the graph defined here is always in sync with that of the BiRNN method?
Good question. Suggestions welcome. I'm still thinking about how to solve that without complicating the code more than it's worth.
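One possible approach (a sketch, not code from this PR; names and hyperparameters are illustrative): factor the shared layers into a single builder function that both BiRNN and create_inference_step_graph call, so the two graphs cannot drift apart.

```python
import tensorflow as tf

def acoustic_model(batch_x, seq_length, previous_state=None,
                   n_hidden=2048, n_classes=29, relu_clip=20.0):
    """Single source of truth for the model layers. Both the training graph
    and the streaming inference graph would call this, so any architecture
    change is automatically reflected in both."""
    # Feed-forward layers 1-3, with clipped ReLU activations.
    layer = batch_x
    for i in (1, 2, 3):
        with tf.variable_scope('layer_%d' % i, reuse=tf.AUTO_REUSE):
            layer = tf.minimum(tf.nn.relu(tf.layers.dense(layer, n_hidden)),
                               relu_clip)
    # Unidirectional LSTM (layer 4), optionally seeded with carried-over
    # state so the streaming graph can resume where the last batch ended.
    cell = tf.nn.rnn_cell.LSTMCell(n_hidden, reuse=tf.AUTO_REUSE)
    outputs, state = tf.nn.dynamic_rnn(cell, layer,
                                       sequence_length=seq_length,
                                       initial_state=previous_state,
                                       dtype=tf.float32,
                                       time_major=True)
    # Output layer producing the logits.
    with tf.variable_scope('layer_6', reuse=tf.AUTO_REUSE):
        logits = tf.layers.dense(outputs, n_classes)
    return logits, state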
DeepSpeech.py
Outdated
    h3 = variable_on_worker_level('h3', [n_hidden_2, n_hidden_3], tf.random_normal_initializer(stddev=FLAGS.h3_stddev))
    layer_3 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_2, h3), b3)), FLAGS.relu_clip)

    # Now we create the forward and backward LSTM units.
Comment now incorrect
DeepSpeech.py
Outdated
    layer_3 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_2, h3), b3)), FLAGS.relu_clip)

    # Now we create the forward and backward LSTM units.
    # Both of which have inputs of length `n_cell_dim` and bias `1.0` for the forget gate of the LSTM.
Comment now incorrect
DeepSpeech.py
Outdated
    # Output shape: [n_steps, batch_size, n_hidden_6]
    return layer_6, output_state

def create_inference_graph(batch_size=None, n_steps=16, use_new_decoder=False):
How will we make sure graph elements, for example the use of a BasicLSTMCell, stay consistent across this method and BiRNN?
native_client/client.cc
Outdated
@@ -26,59 +26,39 @@

 using namespace DeepSpeech;

-struct ds_result {
+typedef struct {
If we're modernizing a bit, why not use a class instead of a struct?
Also, why not use a std::string? (If this string is going over the ABI boundary at some point, ignore this statement.)
A class is just a struct where members are private by default. I don't see how either is more modern, but I made them classes anyway.
I don't use an std::string here because that makes it harder to handle the nullptr return case of stt.
   return res;
 }

-struct ds_audio_buffer {
+typedef struct {
The same comments about modernizing apply here.
native_client/deepspeech.h
Outdated
@@ -8,24 +8,14 @@
 namespace DeepSpeech
 {

-class Private;
+struct Private;
Why switch from class to struct?
Reverted.
native_client/deepspeech.h
Outdated
-class Private;
+struct Private;

+struct StreamingState;
Why use struct?
Reverted.
native_client/deepspeech.h
Outdated
-class Private;
+struct Private;

 struct StreamingState;

 class Model {
   private:
     Private* mPriv;
Classes that pass over the ABI boundary should have a non-inline virtual destructor as far as I know.
I assume you're talking about Model, not Private, as that's an opaque pointer. I made Model's destructor virtual.
Force-pushed from 28988b8 to b3f86da
Tests passing, modulo AOT and production tests, which is expected.
DeepSpeech.py
Outdated
                         initializer_nodes='',
                         variable_names_blacklist='previous_state_c,previous_state_h')

         # Load and export as string
Why do we want to export it as both pb and pbtxt?
I was going to remove that but forgot. It's there so I can easily load an exported graph into Tensorboard to visualize it.
Right. Do we need this by default? I fear it's just going to waste disk space and time for most use cases. Would an extra CLI flag defaulting to false be a good compromise?
I'll make it a standalone tool in bin/graph_binary_to_text.py
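A minimal sketch of what such a tool could look like (the actual bin/graph_binary_to_text.py may differ; argument handling here is illustrative):

```python
#!/usr/bin/env python
# Hypothetical sketch of bin/graph_binary_to_text.py: load a frozen binary
# GraphDef and write it back out as pbtxt, e.g. "model.pb" -> "model.pbtxt",
# for loading into Tensorboard or manual inspection.
import sys
import tensorflow as tf
from google.protobuf import text_format

def main():
    graph_def = tf.GraphDef()
    with open(sys.argv[1], 'rb') as fin:
        graph_def.ParseFromString(fin.read())
    with open(sys.argv[1] + 'txt', 'w') as fout:
        fout.write(text_format.MessageToString(graph_def))

if __name__ == '__main__':
    main()
```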
native_client/BUILD
Outdated
### => Trying to be more fine-grained
### Obtained by trial/error process ...
### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
👍
     ]),
-    includes = ["kenlm", "boost_locale"],
+    includes = ["kenlm", "boost_locale", "c_speech_features", "kiss_fft130"],
What about the copts tuning to disable AVX, AVX2 and FMA for performance reasons?
Good catch. I'll try to get bazel to do the right thing here.
Maybe the reason it was added is also no longer valid? It would be useful to re-evaluate that :)
     ] + if_native_model([
         "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
-    ])
-    + if_cuda([
+    ])+ if_cuda([
I'm (happily) surprised you could remove the Slice GPU-specific dependency, but it might be a side effect of the graph change itself?
Probably. There are no Slice ops in the streaming graph.
native_client/BUILD
Outdated
### => Trying to be more fine-grained
### Obtained by trial/error process ...
### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
### CPU only build libdeepspeech.so from 63M to 36M
nit: update the note on the size impact?
Will check.
@@ -53,7 +53,7 @@ class KenLMBeamScorer : public tensorflow::ctc::BaseBeamScorer<KenLMBeamState> {
   // ExpandState is called when expanding a beam to one of its children.
   // Called at most once per child beam. In the simplest case, no state
   // expansion is done.
-  void ExpandState(const KenLMBeamState& from_state, int from_label,
+  void ExpandState(const KenLMBeamState& from_state, int /*from_label*/,
Why is this parameter name commented out and not the rest?
Because that's the only parameter that isn't used. This was done to silence a warning.
native_client/beam_search.h
Outdated
#endif /* BEAM_SEARCH_H */
nit: whitespace-only change because of the line ending
That was on purpose, to silence a warning about the missing newline at the end of the file.
native_client/client.cc
Outdated
@@ -26,59 +26,39 @@

 using namespace DeepSpeech;

-struct ds_result {
+typedef struct {
   char* string;
   double cpu_time_overall;
For comparison purposes, I'd like us to keep all three measurements :-)
util/audio.py
Outdated
 try:
-    from deepspeech.utils import audioToInputVector
+    from deepspeech import audioToInputVector
 except ImportError:
     import numpy as np
     from python_speech_features import mfcc
nit: we should update the warning message to make it clearer:
DeepSpeech Python bindings could not be loaded, resorting to slower code to compute the audio vectors.
Refer to README.md for instructions on how to install (or build) deepspeech python bindings.
Good catch, will do.
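A minimal sketch of the resulting fallback path, with the warning text taken verbatim from the suggestion above (the exact structure of util/audio.py may differ):

```python
try:
    from deepspeech import audioToInputVector
except ImportError:
    import warnings
    warnings.warn('DeepSpeech Python bindings could not be loaded, resorting '
                  'to slower code to compute the audio vectors. Refer to '
                  'README.md for instructions on how to install (or build) '
                  'the deepspeech python bindings.')
    import numpy as np
    from python_speech_features import mfcc
    # ... pure-Python fallback implementation of audioToInputVector ...
```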
native_client/deepspeech.h
Outdated
 * to {@link feedAudioContent()} and {@link finishStream()}.
 *
 * @param aPreAllocFrames Number of timestep frames to reserve. One timestep
 *                        is equivalent to two window lenghts (50ms), so
Is the fact that "one timestep is equivalent to two window lengths" dependent on parameters set during training? We should mention it if that's the case. (And there's a typo in lenghts :))
It's dependent on the training code, yes. Feature computation for training and inference needs to be consistent.
Maybe it's worth mentioning?
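For illustration, the arithmetic implied by the docstring (a sketch; the 50 ms per timestep figure comes from the quoted comment, and the helper name is made up):

```python
MS_PER_TIMESTEP = 50  # per the docstring: one timestep = two window lengths

def frames_to_prealloc(expected_audio_seconds):
    """Hypothetical helper: timestep frames to reserve for a clip of the
    given length. The default aPreAllocFrames = 150 covers 7.5 seconds."""
    return int(expected_audio_seconds * 1000 / MS_PER_TIMESTEP)

print(frames_to_prealloc(7.5))  # -> 150, matching the default
```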
native_client/deepspeech.h
Outdated
 * @return A context pointer that represents the streaming state.
 */
StreamingState* setupStream(unsigned int aPreAllocFrames = 150,
                            unsigned int aSampleRate = 16000);
The more I think about it, the more I wonder whether we should have some graph metadata to rely on for this.
aPreAllocFrames is just an optimization knob, so not tied to the graph. The sample rate is kind of tied, in that it should be larger than or equal to the sample rate of the training data. (This requirement might go away depending on the result of our band-pass experiments).
native_client/deepspeech.cc
Outdated
const int SAMPLE_RATE = 16000;

//TODO: infer n_steps from model
const int N_STEPS_PER_BATCH = 16;
Those three TODOs about using data from the model seem related to my previous comment about relying on some graph metadata.
Yeah, n_steps and batch_size should definitely be obtained from the model directly. I think they're both easy to do, FWIW, just haven't gotten to it yet.
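One possible way to do that (a Python sketch for brevity; the native client would do the equivalent in C++ against the loaded GraphDef, and the placeholder name input_node is an assumption):

```python
import tensorflow as tf

def infer_n_steps(model_path, input_name='input_node'):
    """Read n_steps from the input placeholder's shape attribute instead of
    hard-coding N_STEPS_PER_BATCH. Assumes the placeholder has a static
    shape of [batch_size, n_steps, ...]."""
    graph_def = tf.GraphDef()
    with open(model_path, 'rb') as fin:
        graph_def.ParseFromString(fin.read())
    for node in graph_def.node:
        if node.name == input_name:
            return node.attr['shape'].shape.dim[1].size
    raise ValueError('placeholder %r not found in graph' % input_name)
```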
native_client/deepspeech.cc
Outdated
using namespace tensorflow;
using tensorflow::ctc::CTCBeamSearchDecoder;
using tensorflow::ctc::CTCDecoder;

namespace DeepSpeech {

class StreamingState {
  public:
    std::vector<float> accumulated_logits;
float: is there any win/loss compared to double, for example?
The model outputs a float32 tensor, so there's no gain.
native_client/deepspeech.cc
Outdated
using namespace tensorflow;
using tensorflow::ctc::CTCBeamSearchDecoder;
using tensorflow::ctc::CTCDecoder;

namespace DeepSpeech {

class StreamingState {
As far as I can see, this class is completely defined here; in deepspeech.h we only have a class StreamingState; forward declaration. Therefore, I think it would be useful to document this properly, since StreamingState is exposed to the public.
It's an opaque pointer, so the layout is not exposed to API consumers. Is the documentation in setupStream/feedAudioContent/finishStream not enough?
You have a point here. The other docs are good; I missed the fact that since it's not in the header, it's not publicly exposed. We got some people on Discourse asking to play with the logits, if I remember correctly. Maybe there's room to expose some of that?
I'd rather expose too little at first, than too much and be forced to maintain the "too much" part.
native_client/deepspeech.cc
Outdated
 *
 * @param n_frames Number of timesteps to deal with
 * @param logits Matrix of logits, of dimensions:
 *               [n_frames][batch_size][num_classes]
The comment states it is a matrix of [n_frames][batch_size][num_classes] dimensions, but the prototype is std::vector<float>& logits, which seems like a simple vector. Am I missing something? Also, there is no n_frames parameter either :'(
Good catch, I changed this from bare pointers to std::vector and forgot to update the comments.
 * @param mfcc batch input data
 * @param n_frames number of timesteps in the data
 *
 * @param[out] output_logits Should be large enough to fit
nit: should this be @return instead of @param[out]?
No, it's an out-parameter; we append the results of the inference to it.
Right, hence the void.
Force-pushed from ad47c78 to a53b64c
Rebased on top of master and included changes made recently to the decoder. I think this is good to merge (after a review pass), but before releasing we should deal with some infrastructure bits included here for development:
Force-pushed from 25a189a to 6fd09bf
Force-pushed from 9d26197 to eb6f5f6
Force-pushed from eb6f5f6 to 7b87336
I'm going to close this PR and open a new one for the second round of reviews, since this one is so large it takes almost a minute to load on my computer.
Is there a new PR yet? Found it: it is probably the 7b87336 mentioned just above.
This work was merged already.
This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.
This is the WIP of streaming inference support. We need to decide on a unidirectional architecture to use before we can merge this PR. Still pending are better documentation of the streaming implementation and actual testing of the AOT model code paths.
At this point I'd appreciate a review of the API changes in deepspeech.h (see the clients and DeepSpeech::Model::stt for example usage).